{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "63685727",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from lee_et_al_2023.src import base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28367ee7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "pc_df = pd.read_csv(base.DATA_PATH / 'fig_1_pc_data.csv')\n",
    "\n",
    "pca_gnn = pc_df[['gnn_pc1', 'gnn_pc2']].values\n",
    "pca_fp = pc_df[['fp_pc1', 'fp_pc2']].values\n",
    "pca_label = pc_df[['label_pc1', 'label_pc2']].values\n",
    "pc_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4949112",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def plot_islands(bg_points,\n",
    "                 bg_name,\n",
    "                 *fg_specs,\n",
    "                 alpha = 0.4,\n",
    "                 z_limit = 15,\n",
    "                 colors=None):\n",
    "    \"\"\"Plot islands in a background sea of points.\n",
    "\n",
    "    This command executes matplotlib.pyplot commands as side effects.\n",
    "    Use plt.figure() to control where these outputs get generated.\n",
    "\n",
    "    Args:\n",
    "    bg_points: Array of shape (n_points, 2) indicating x,y coordinates\n",
    "        of any background points.\n",
    "    bg_name: Name of background points\n",
    "    *fg_specs: Repeated island specifications. Allowable dict keys:\n",
    "        data - array of shape (n_points, 2)\n",
    "        label - str, name of the island\n",
    "        color - RGB, color of the island\n",
    "        scatter - bool, whether to add the point scatters\n",
    "        scatter_size - int, radius of scatter points in pixel\n",
    "        filled - bool, whether to fill the island\n",
    "        levels_to_plot - List[float], percentiles of the KDE to plot.\n",
    "    alpha: Transparency of island color fill.\n",
    "    z_limit: Plotting boundaries of the plot.\n",
    "    \"\"\"\n",
    "    default_fg_colors = colors or sns.color_palette('Set3', len(fg_specs))\n",
    "\n",
    "    sns.scatterplot(x=bg_points[:, 0], y=bg_points[:, 1],\n",
    "              s=3, color='0.60', label=bg_name)\n",
    "    for fg_color, fg_spec in zip(default_fg_colors, fg_specs):\n",
    "        fg_color = fg_spec.get('color', fg_color)\n",
    "        fg_scatter = fg_spec.get('scatter', False)\n",
    "        fg_scatter_size = fg_spec.get('scatter_size', 5)\n",
    "        fg_filled = fg_spec.get('filled', False)\n",
    "        fg_level_to_plot = fg_spec.get('level_to_plot', 0.25)\n",
    "        x, y = fg_spec['data'][:, 0], fg_spec['data'][:, 1]\n",
    "        label = fg_spec['label']\n",
    "        if fg_scatter:\n",
    "            sns.scatterplot(x, y, s=fg_scatter_size, color=fg_color)\n",
    "        if fg_filled:\n",
    "            sns.kdeplot(x=x, y=y, color=fg_color, fill=True,\n",
    "                      thresh=fg_level_to_plot, alpha=alpha, levels=2, bw_method=0.3)\n",
    "        else:\n",
    "            sns.kdeplot(x=x, y=y, color=fg_color, fill=False,\n",
    "                      thresh=fg_level_to_plot, levels=2, bw_method=0.3)\n",
    "        # Generate the legend entry\n",
    "        if fg_filled:\n",
    "            plt.scatter([], [], marker='s', c=fg_color, label=label)\n",
    "        else:\n",
    "            plt.plot([], [], c=fg_color, linewidth=3, label=label)\n",
    "    plt.xlim([-z_limit - 0.6, z_limit + 0.6])\n",
    "    plt.ylim([-z_limit - 0.6, z_limit + 0.6])\n",
    "    plt.gca().set_aspect('equal', adjustable='box')\n",
    "    \n",
    "\n",
    "def plot_odor_islands(pca_space, z_limit=15):\n",
    "    plt.figure(figsize=(12, 8))\n",
    "    color_palette = sns.color_palette('Set3', 15)\n",
    "    fg_specs = []\n",
    "    i = 0\n",
    "    for main_group, subgroups in [('floral', ['muguet', 'lavender', 'jasmin']),\n",
    "                                ('meaty', ['savory', 'beefy', 'roasted']),\n",
    "                                ('alcoholic', ['cognac', 'fermented', 'winey']),\n",
    "                                ]:\n",
    "        main_embeddings = pca_space[pc_df[main_group]]\n",
    "        fg_specs.append({'data': main_embeddings,\n",
    "                        'filled': True,\n",
    "                        'scatter': False,\n",
    "                        'level_to_plot': 0.1,\n",
    "                        'label': main_group.capitalize()})\n",
    "\n",
    "        for subgroup in subgroups:\n",
    "            island_embeddings = pca_space[pc_df[subgroup]]\n",
    "            fg_specs.append({'data': island_embeddings,\n",
    "                          'filled': False,\n",
    "                          'scatter': False,\n",
    "                          'level_to_plot': 0.2,\n",
    "                          'label': subgroup.capitalize()})\n",
    "        plot_islands(pca_space, None, *fg_specs, colors=color_palette[i:i+4], z_limit=z_limit)\n",
    "        i += 5\n",
    "        fg_specs = []\n",
    "        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))\n",
    "\n",
    "plot_odor_islands(pca_gnn, z_limit=15)\n",
    "plt.title('GNN Embeddings')\n",
    "plot_odor_islands(pca_fp, z_limit=10)\n",
    "plt.title('Fingerprints')\n",
    "plot_odor_islands(pca_label, z_limit=8)\n",
    "plt.title('True Labels')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "auto:light,ipynb"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
